For those students taking the course for credit, the work done during this workshop is to be handed in. Please e-mail both your Rmd source file and html output to hst953hw@mit.edu no later than Friday Nov 4, 2016.
To complete the assignment, fill in the necessary code in the places marked # Students: Insert your code here, and provide text-based answers where marked ### Student Answer
Before beginning, please test whether the Rmd file will compile on your system by clicking the “Knit HTML” button in RStudio.
Prediction usually refers to using a statistical model to determine the expectation of different outcomes for a patient with a given set of covariate and confounder values. For example, let’s load the aline dataset from the previous workshop.
dat <- read.csv("aline-dataset.csv")
library(plyr);library(Hmisc)
## Loading required package: lattice
## Loading required package: survival
## Loading required package: Formula
## Loading required package: ggplot2
##
## Attaching package: 'Hmisc'
## The following objects are masked from 'package:plyr':
##
## is.discrete, summarize
## The following objects are masked from 'package:base':
##
## format.pval, round.POSIXt, trunc.POSIXt, units
dat$sofa_cat <- cut2(dat$sofa_first,c(0,4,7))
dat$age.cat <- cut2(dat$age,c(50,60,70,80))
dat$service_unit2 <- as.character(dat$service_unit)
dat$service_unit2[dat$service_unit2 %in% names(which(table(dat$service_unit)<200))] <- "Other"
Let’s fit the full model we considered for 28 day mortality in the previous workshop.
full.model.glm <- glm(day_28_flg ~ aline_flg + age.cat + sofa_cat + service_unit2 + renal_flg + chf_flg + cad_flg + stroke_flg + mal_flg + resp_flg,data=dat,family="binomial") # Note: uses service_unit2, the collapsed version of service_unit
summary(full.model.glm)
##
## Call:
## glm(formula = day_28_flg ~ aline_flg + age.cat + sofa_cat + service_unit2 +
## renal_flg + chf_flg + cad_flg + stroke_flg + mal_flg + resp_flg,
## family = "binomial", data = dat)
##
## Deviance Residuals:
## Min 1Q Median 3Q Max
## -1.9002 -0.5027 -0.2906 -0.2048 2.9417
##
## Coefficients:
## Estimate Std. Error z value Pr(>|z|)
## (Intercept) -3.80050 0.22133 -17.171 < 2e-16 ***
## aline_flg -0.02825 0.13478 -0.210 0.83400
## age.cat[ 50, 60) 0.70631 0.24282 2.909 0.00363 **
## age.cat[ 60, 70) 1.37080 0.22770 6.020 1.74e-09 ***
## age.cat[ 70, 80) 1.81831 0.21729 8.368 < 2e-16 ***
## age.cat[ 80,300] 2.58962 0.21106 12.270 < 2e-16 ***
## sofa_cat[ 4, 7) 0.25631 0.14661 1.748 0.08043 .
## sofa_cat[ 7,14] 0.51149 0.19158 2.670 0.00759 **
## service_unit2NMED 0.17722 0.21414 0.828 0.40792
## service_unit2NSURG 0.15114 0.21851 0.692 0.48915
## service_unit2Other -0.33926 0.23511 -1.443 0.14903
## service_unit2SURG -1.39823 0.39424 -3.547 0.00039 ***
## service_unit2TRAUM -0.02558 0.23906 -0.107 0.91480
## renal_flg 0.03357 0.24058 0.140 0.88901
## chf_flg -0.96178 0.24702 -3.894 9.88e-05 ***
## cad_flg -0.43435 0.18891 -2.299 0.02149 *
## stroke_flg 1.74596 0.17918 9.744 < 2e-16 ***
## mal_flg 0.79214 0.16578 4.778 1.77e-06 ***
## resp_flg 0.65706 0.13896 4.728 2.26e-06 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## (Dispersion parameter for binomial family taken to be 1)
##
## Null deviance: 2340.8 on 2750 degrees of freedom
## Residual deviance: 1737.2 on 2732 degrees of freedom
## AIC: 1775.2
##
## Number of Fisher Scoring iterations: 6
Imagine we have a patient just like the patient in hadm_id=137140:
new.pt <- subset(dat,hadm_id==137140)
new.pt
## subject_id hadm_id icustay_id age gender_num icustay_intime
## 123 29040 137140 204010 85.56538 1 2143-06-16 01:43:38
## day_icu_intime_num hour_icu_intime icustay_outtime icu_los_day
## 123 0 1 2143-06-17 04:22:20 1.110208
## hospital_los_day hosp_exp_flg icu_exp_flg mort_day day_28_flg
## 123 2.675694 0 0 NA 0
## mort_day_censored censor_flg aline_flg aline_time_day weight_first
## 123 150 1 0 NA 77.3
## height_first bmi service_unit sofa_first map_first hr_first
## 123 1.7526 25.16598 MED 1 67 77
## temp_first spo2_first cvp_first bun_first creatinine_first
## 123 36.16667 92 NA 24 1
## chloride_first hgb_first platelet_first potassium_first sodium_first
## 123 100 15.4 248 4 140
## tco2_first wbc_first chf_flg afib_flg renal_flg liver_flg copd_flg
## 123 29 11.7 0 1 0 0 0
## cad_flg stroke_flg mal_flg resp_flg endocarditis_flg ards_flg
## 123 1 0 0 0 0 0
## pneumonia_flg sofa_cat age.cat service_unit2
## 123 0 [ 0, 4) [ 80,300] MED
and we’d like to know the patient’s chance of survival at 28 days. We can estimate this using the predict function, which takes the fitted model and a data frame of new observations:
predict(full.model.glm,newdata=new.pt,type="response")
## 123
## 0.1617551
Based on this model, we would predict a probability of death of around 0.16 for this patient, meaning they have about an 84% chance of surviving. But how good is this prediction? We can generate predictions for every patient in the dataset, add them as a new column in the dat data frame, and then plot the distribution of these predictions:
dat$logRegPred <- predict(full.model.glm,newdata=dat,type="response")
hist(dat$logRegPred,breaks=11)
Here we see that the model predicts a fairly low risk of dying for most patients, and a very high risk for a few. One way of judging how good this prediction is, is to look at the accuracy (how often we predict the right outcome). To do this we need to specify a cutoff: above it we make a binary prediction that the patient is likely to die, and below it we predict that the patient will live. Let’s set this cutoff at 0.5:
dat$logRegPred0.5 <- dat$logRegPred>0.5
Then use the table function to see how the outcomes are distributed across our two predictions:
predTab1 <- table(dat$logRegPred0.5,dat$day_28_flg==1,dnn=c("Prediction","Death by 28 Days"))
predTab1
## Death by 28 Days
## Prediction FALSE TRUE
## FALSE 2266 290
## TRUE 68 127
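From this table, accuracy is just the diagonal counts over the total. A sketch using the printed counts (equivalently, sum(diag(predTab1))/sum(predTab1) on the table object itself):

```r
# Accuracy from the 2x2 table above (counts copied from predTab1)
tab <- matrix(c(2266, 68, 290, 127), nrow = 2,
              dimnames = list(Prediction = c("FALSE", "TRUE"),
                              Death      = c("FALSE", "TRUE")))
accuracy <- sum(diag(tab)) / sum(tab)  # correct predictions / all predictions
round(accuracy, 7)  # 0.8698655
```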
The accuracy is the proportion of cases where our prediction matched the actual outcome, which corresponds to the diagonal elements of the 2x2 table. In our case it is 0.8698655. Is this good?
That can be a complicated question. Among other things, it depends on whether 0.5 is the “right” threshold, and on whether this performance will hold up on unseen data. Let’s tackle each of these in reverse order.
When we train a model using glm (or any other algorithm), we are optimizing its performance on this training dataset, and it’s unlikely that the performance will be as rosy when the model is applied to unseen data.
Let’s explore this a little further. We will divide the dat data frame into two datasets: datTrain and datTest, by randomly selecting ICU stays to be in each, so that about 50% of our data is in the training dataset and about 50% is contained in the testing dataset.
We first set a seed which makes any random selections reproducible. Then we use the createDataPartition function to sample indexes in the dat data frame to include in datTrain. Then we use these indexes to establish training (datTrain) and testing (datTest) datasets.
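If caret is unavailable, a plain base-R split does much the same job (a sketch; note that createDataPartition additionally stratifies the split on the outcome, which a plain sample does not):

```r
set.seed(4441)
n <- 2751                                     # number of rows in dat
trainIdx <- sample(n, size = round(0.5 * n))  # random 50% of row indices
testIdx  <- setdiff(seq_len(n), trainIdx)     # the remaining rows
# datTrain <- dat[trainIdx, ]; datTest <- dat[-trainIdx, ]
c(train = length(trainIdx), test = length(testIdx))
```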
set.seed(4441) # We do this so it's reproducible!
library(caret)
##
## Attaching package: 'caret'
## The following object is masked from 'package:survival':
##
## cluster
trainIdx <- createDataPartition(dat$day_28_flg,p=0.5)$Resample1
datTrain <- dat[trainIdx,]
datTest <- dat[-trainIdx,]
Now we repeat the glm fit from the previous section, but using only the datTrain observations:
train.glm <- glm(day_28_flg ~ aline_flg + age.cat + sofa_cat + service_unit2 + renal_flg + chf_flg + cad_flg + stroke_flg + mal_flg + resp_flg,data=datTrain,family="binomial")
summary(train.glm)
##
## Call:
## glm(formula = day_28_flg ~ aline_flg + age.cat + sofa_cat + service_unit2 +
## renal_flg + chf_flg + cad_flg + stroke_flg + mal_flg + resp_flg,
## family = "binomial", data = datTrain)
##
## Deviance Residuals:
## Min 1Q Median 3Q Max
## -1.9626 -0.4711 -0.2895 -0.1917 3.0817
##
## Coefficients:
## Estimate Std. Error z value Pr(>|z|)
## (Intercept) -3.76468 0.31776 -11.848 < 2e-16 ***
## aline_flg 0.10727 0.19828 0.541 0.588531
## age.cat[ 50, 60) 0.44860 0.36252 1.237 0.215924
## age.cat[ 60, 70) 1.28783 0.32516 3.961 7.47e-05 ***
## age.cat[ 70, 80) 1.69481 0.31116 5.447 5.13e-08 ***
## age.cat[ 80,300] 2.68516 0.29402 9.133 < 2e-16 ***
## sofa_cat[ 4, 7) 0.02339 0.21900 0.107 0.914945
## sofa_cat[ 7,14] 0.35797 0.28937 1.237 0.216061
## service_unit2NMED 0.50722 0.29642 1.711 0.087055 .
## service_unit2NSURG 0.14043 0.31686 0.443 0.657622
## service_unit2Other -0.61614 0.36892 -1.670 0.094893 .
## service_unit2SURG -2.07754 0.74940 -2.772 0.005566 **
## service_unit2TRAUM -0.22322 0.34613 -0.645 0.518976
## renal_flg 0.03382 0.32362 0.105 0.916763
## chf_flg -0.83077 0.33109 -2.509 0.012100 *
## cad_flg -0.39529 0.27705 -1.427 0.153648
## stroke_flg 1.45182 0.25079 5.789 7.08e-09 ***
## mal_flg 0.71311 0.23829 2.993 0.002766 **
## resp_flg 0.75829 0.20117 3.769 0.000164 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## (Dispersion parameter for binomial family taken to be 1)
##
## Null deviance: 1137.31 on 1375 degrees of freedom
## Residual deviance: 833.12 on 1357 degrees of freedom
## AIC: 871.12
##
## Number of Fisher Scoring iterations: 6
Let’s create predictions for this model in the training dataset:
datTrain$logRegPred <- predict(train.glm,newdata=datTrain,type="response")
datTrain$logRegPred0.5 <- datTrain$logRegPred>0.5
predTabTr <- table(datTrain$logRegPred0.5,datTrain$day_28_flg==1,dnn=c("Prediction","Death by 28 Days"))
predTabTr
## Death by 28 Days
## Prediction FALSE TRUE
## FALSE 1145 144
## TRUE 32 55
The accuracy is pretty similar to that obtained using the whole dataset: 0.872093.
Now, let’s try it on the test dataset. This dataset was not used to build the model, so it gives a better assessment of the performance of the prediction algorithm:
datTest$logRegPred <- predict(train.glm,newdata=datTest,type="response")
datTest$logRegPred0.5 <- datTest$logRegPred>0.5
predTabTe <- table(datTest$logRegPred0.5,datTest$day_28_flg==1,dnn=c("Prediction","Death by 28 Days"))
predTabTe
## Death by 28 Days
## Prediction FALSE TRUE
## FALSE 1122 156
## TRUE 35 62
Here the accuracy is about 0.8610909, slightly less than in the training dataset. In this case the discrepancy between training and test performance is not large, but in many circumstances it can be substantial.
In the previous section, we again used 0.5 as a threshold, but this was completely arbitrary and may not be a good cutoff for our particular application.
Looking back at predTabTe: among those who survived, we correctly predicted survival in 1122 cases, or about 97%. Among those who died, we correctly predicted death in only 28.44% of cases.
Our accuracy only looks good because this population contains far more survivors than deaths. The two quantities we just calculated are often called the specificity (the proportion of survivors we correctly predict as survivors, aka the true negative rate) and the sensitivity (the proportion of deaths we correctly identify as patients who died, aka the true positive rate).
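From the predTabTe counts above, these two rates can be computed directly:

```r
# Sensitivity and specificity from the predTabTe counts (test set)
TN <- 1122; FP <- 35   # survivors: predicted to survive / predicted to die
FN <- 156;  TP <- 62   # deaths:    predicted to survive / predicted to die
specificity <- TN / (TN + FP)  # true negative rate, ~0.97
sensitivity <- TP / (TP + FN)  # true positive rate, ~0.28
c(specificity = specificity, sensitivity = sensitivity)
```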
We of course picked 0.5 as a cutoff, but this was arbitrary. One way of getting around this arbitrariness is to evaluate all potential cutoffs through a Receiver Operating Characteristic curve (ROC curve).
ROC curves plot 1-specificity versus the sensitivity of the algorithm for predicting the outcome, while varying the cutoff used to define the predictions. Evaluation is usually done with the area under the curve (AUC): an AUC of 1 indicates perfect prediction, while an AUC of 0.5 is no better than “flipping a coin”.
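Conceptually, an ROC curve just repeats the sensitivity/specificity calculation at every possible cutoff. A minimal sketch on made-up data (not the ICU dataset; ROCR does this, and AUC computation, for us):

```r
p <- c(0.1, 0.2, 0.35, 0.4, 0.7, 0.8)  # toy predicted probabilities
y <- c(0,   0,   1,    0,   1,   1)    # toy observed outcomes
cuts <- sort(unique(c(0, p, 1)), decreasing = TRUE)    # every candidate cutoff
tpr <- sapply(cuts, function(t) mean(p[y == 1] >= t))  # sensitivity at cutoff t
fpr <- sapply(cuts, function(t) mean(p[y == 0] >= t))  # 1 - specificity at cutoff t
plot(fpr, tpr, type = "s", xlab = "1 - Specificity", ylab = "Sensitivity")
```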
For our training and test datasets, the ROC curves and AUCs are presented below.
library(ROCR)
## Loading required package: gplots
##
## Attaching package: 'gplots'
## The following object is masked from 'package:stats':
##
## lowess
predTr <- prediction(datTrain$logRegPred,datTrain$day_28_flg)
perfTr <- performance(predTr,"tpr","fpr")
plot(perfTr)
text(0.6,0.2,paste0("AUC: ", round(performance(predTr,"auc")@y.values[[1]],3)))
predTe <- prediction(datTest$logRegPred,datTest$day_28_flg)
perfTe <- performance(predTe,"tpr","fpr")
lines(perfTe@x.values[[1]],perfTe@y.values[[1]],col='red')
text(0.6,0.1,paste0("AUC: ", round(performance(predTe,"auc")@y.values[[1]],3)),col='red')
Here we can see a few important things, most notably how the training (black) and test (red) curves and their AUCs compare.
ROC curves tell us about discrimination (how well we are able to distinguish between survivors and deaths), but an equally important aspect is the calibration of our model. For example, if we say that a patient has a 99% chance of dying, and while this patient is at higher risk than the average patient, their actual risk is far less than 99% (e.g., 20%), then our model is poorly calibrated.
There are qualitative and quantitative assessments of calibration. A qualitative assessment can be done using the calibrate.plot function in the gbm package:
#install.packages("gbm") # if this chunk fails, install gbm package
prop.table(table(datTrain$day_28_flg,cut2(datTrain$logRegPred,seq(0,1,0.1))),2)
##
## [0.0,0.1) [0.1,0.2) [0.2,0.3) [0.3,0.4) [0.4,0.5) [0.5,0.6)
## 0 0.96466431 0.81666667 0.74814815 0.68055556 0.54716981 0.64285714
## 1 0.03533569 0.18333333 0.25185185 0.31944444 0.45283019 0.35714286
##
## [0.6,0.7) [0.7,0.8) [0.8,0.9) [0.9,1.0]
## 0 0.35000000 0.10714286 0.40000000 0.00000000
## 1 0.65000000 0.89285714 0.60000000 1.00000000
gbm::calibrate.plot(datTrain$day_28_flg,datTrain$logRegPred)
prop.table(table(datTest$day_28_flg,cut2(datTest$logRegPred,seq(0,1,0.1))),2)
##
## [0.0,0.1) [0.1,0.2) [0.2,0.3) [0.3,0.4) [0.4,0.5) [0.5,0.6)
## 0 0.94945848 0.84375000 0.76562500 0.58823529 0.55932203 0.42857143
## 1 0.05054152 0.15625000 0.23437500 0.41176471 0.44067797 0.57142857
##
## [0.6,0.7) [0.7,0.8) [0.8,0.9) [0.9,1.0]
## 0 0.42105263 0.26470588 0.37500000 0.00000000
## 1 0.57894737 0.73529412 0.62500000 1.00000000
gbm::calibrate.plot(datTest$day_28_flg,datTest$logRegPred)
More formal testing can be done using the Hosmer-Lemeshow test (see ?hoslem_gof in the sjstats package). Beware: the null hypothesis is that the model fits the observed data!
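The idea behind the Hosmer-Lemeshow statistic is simple enough to sketch by hand: bin the predictions into deciles, then compare observed and expected event counts in each bin. A sketch on simulated, well-calibrated data (the packaged implementations handle ties and degrees of freedom more carefully):

```r
set.seed(1)
p <- runif(500)          # simulated predicted probabilities
y <- rbinom(500, 1, p)   # outcomes drawn from those probabilities (well calibrated)
bins <- cut(p, quantile(p, seq(0, 1, 0.1)), include.lowest = TRUE)
obs  <- tapply(y, bins, sum)     # observed events per decile
expd <- tapply(p, bins, sum)     # expected events per decile
ng   <- tapply(p, bins, length)  # bin sizes
H <- sum((obs - expd)^2 / (expd * (1 - expd / ng)))  # HL chi-square statistic
pchisq(H, df = 10 - 2, lower.tail = FALSE)  # large p => no evidence of miscalibration
```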
- Moving towards a topic which we will begin next week: sometimes in observational data, we want to model who gets the treatment (or more generally, who gets exposed). With this in mind, you might find it useful to build a model to predict treatment. Build a ‘full model’ with the variables we used above, using aline_flg as the response (‘outcome’) and all other variables as predictors (covariates), except do not include aline_flg or day_28_flg as predictors. Report the accuracy of this model using the 0.5 cutoff, and give an estimate of the accuracy of a ‘baseline’ model (one which uses no covariates).
- Repeat part a), but use the training set we defined above to fit the model. Evaluate the accuracy of the model as in a) in both the training and test sets.
- Plot an ROC curve for the training and test sets.
- Assess the calibration of the model in both the training and test sets.
# a
## full model
full.glm <- glm(aline_flg ~ age.cat + sofa_cat + service_unit2 + renal_flg + chf_flg + cad_flg + stroke_flg + mal_flg + resp_flg, data=dat, family="binomial")
dat$logRegPred <- predict(full.glm, newdata=dat, type="response")
dat$logRegPred0.5 <- dat$logRegPred>0.5
(predTab <- table(dat$logRegPred0.5, dat$aline_flg==1,
dnn=c("Prediction","Aline")))
## Aline
## Prediction FALSE TRUE
## FALSE 857 362
## TRUE 501 1031
confusionMatrix(dat$logRegPred0.5, dat$aline_flg==1)
## Confusion Matrix and Statistics
##
## Reference
## Prediction FALSE TRUE
## FALSE 857 362
## TRUE 501 1031
##
## Accuracy : 0.6863
## 95% CI : (0.6686, 0.7036)
## No Information Rate : 0.5064
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.3717
## Mcnemar's Test P-Value : 2.633e-06
##
## Sensitivity : 0.6311
## Specificity : 0.7401
## Pos Pred Value : 0.7030
## Neg Pred Value : 0.6730
## Prevalence : 0.4936
## Detection Rate : 0.3115
## Detection Prevalence : 0.4431
## Balanced Accuracy : 0.6856
##
## 'Positive' Class : FALSE
##
## baseline
base.glm <- glm(aline_flg ~ 1, data=dat, family="binomial")
dat$logRegBasePred <- predict(base.glm, newdata=dat, type="response")
dat$logRegBasePred0.5 <- dat$logRegBasePred>0.5
(predTab <- table(dat$logRegBasePred0.5, dat$aline_flg==1,
dnn=c("Prediction","Aline")))
## Aline
## Prediction FALSE TRUE
## TRUE 1358 1393
predTab[1] / sum(predTab)
## [1] 0.4936387
# b
set.seed(777)
trainIdx <- createDataPartition(dat$aline_flg, p=0.5)$Resample1
datTrain <- dat[trainIdx,]
datTest <- dat[-trainIdx,]
## full model
train.glm <- glm(aline_flg ~ age.cat + sofa_cat + service_unit2 + renal_flg + chf_flg + cad_flg + stroke_flg + mal_flg + resp_flg, data=datTrain, family="binomial")
datTrain$logRegPred <- predict(train.glm, newdata=datTrain, type="response")
datTrain$logRegPred0.5 <- datTrain$logRegPred>0.5
(predTab <- table(datTrain$logRegPred0.5, datTrain$aline_flg==1,
dnn=c("Prediction","Aline")))
## Aline
## Prediction FALSE TRUE
## FALSE 436 187
## TRUE 253 500
confusionMatrix(datTrain$logRegPred0.5, datTrain$aline_flg==1)
## Confusion Matrix and Statistics
##
## Reference
## Prediction FALSE TRUE
## FALSE 436 187
## TRUE 253 500
##
## Accuracy : 0.6802
## 95% CI : (0.6549, 0.7048)
## No Information Rate : 0.5007
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.3606
## Mcnemar's Test P-Value : 0.001943
##
## Sensitivity : 0.6328
## Specificity : 0.7278
## Pos Pred Value : 0.6998
## Neg Pred Value : 0.6640
## Prevalence : 0.5007
## Detection Rate : 0.3169
## Detection Prevalence : 0.4528
## Balanced Accuracy : 0.6803
##
## 'Positive' Class : FALSE
##
datTest$logRegPred <- predict(train.glm, newdata=datTest, type="response")
datTest$logRegPred0.5 <- datTest$logRegPred>0.5
(predTab <- table(datTest$logRegPred0.5, datTest$aline_flg==1,
dnn=c("Prediction","Aline")))
## Aline
## Prediction FALSE TRUE
## FALSE 439 209
## TRUE 230 497
confusionMatrix(datTest$logRegPred0.5, datTest$aline_flg==1)
## Confusion Matrix and Statistics
##
## Reference
## Prediction FALSE TRUE
## FALSE 439 209
## TRUE 230 497
##
## Accuracy : 0.6807
## 95% CI : (0.6554, 0.7053)
## No Information Rate : 0.5135
## P-Value [Acc > NIR] : <2e-16
##
## Kappa : 0.3605
## Mcnemar's Test P-Value : 0.3398
##
## Sensitivity : 0.6562
## Specificity : 0.7040
## Pos Pred Value : 0.6775
## Neg Pred Value : 0.6836
## Prevalence : 0.4865
## Detection Rate : 0.3193
## Detection Prevalence : 0.4713
## Balanced Accuracy : 0.6801
##
## 'Positive' Class : FALSE
##
## baseline
train.base.glm <- glm(aline_flg ~ 1, data=datTrain, family="binomial")
datTrain$logRegBasePred <- predict(train.base.glm, newdata=datTrain, type="response")
datTrain$logRegBasePred0.5 <- datTrain$logRegBasePred>0.5
(predTab <- table(datTrain$logRegBasePred0.5, datTrain$aline_flg==1,
dnn=c("Prediction","Aline")))
## Aline
## Prediction FALSE TRUE
## FALSE 689 687
confusionMatrix(datTrain$logRegBasePred0.5, datTrain$aline_flg==1)
## Warning in confusionMatrix.default(datTrain$logRegBasePred0.5, datTrain
## $aline_flg == : Levels are not in the same order for reference and data.
## Refactoring data to match.
## Confusion Matrix and Statistics
##
## Reference
## Prediction FALSE TRUE
## FALSE 689 687
## TRUE 0 0
##
## Accuracy : 0.5007
## 95% CI : (0.474, 0.5275)
## No Information Rate : 0.5007
## P-Value [Acc > NIR] : 0.5108
##
## Kappa : 0
## Mcnemar's Test P-Value : <2e-16
##
## Sensitivity : 1.0000
## Specificity : 0.0000
## Pos Pred Value : 0.5007
## Neg Pred Value : NaN
## Prevalence : 0.5007
## Detection Rate : 0.5007
## Detection Prevalence : 1.0000
## Balanced Accuracy : 0.5000
##
## 'Positive' Class : FALSE
##
datTest$logRegBasePred <- predict(train.base.glm, newdata=datTest, type="response")
datTest$logRegBasePred0.5 <- datTest$logRegBasePred>0.5
(predTab <- table(datTest$logRegBasePred0.5, datTest$aline_flg==1,
dnn=c("Prediction","Aline")))
## Aline
## Prediction FALSE TRUE
## FALSE 669 706
confusionMatrix(datTest$logRegBasePred0.5, datTest$aline_flg==1)
## Warning in confusionMatrix.default(datTest$logRegBasePred0.5, datTest
## $aline_flg == : Levels are not in the same order for reference and data.
## Refactoring data to match.
## Confusion Matrix and Statistics
##
## Reference
## Prediction FALSE TRUE
## FALSE 669 706
## TRUE 0 0
##
## Accuracy : 0.4865
## 95% CI : (0.4598, 0.5133)
## No Information Rate : 0.5135
## P-Value [Acc > NIR] : 0.9785
##
## Kappa : 0
## Mcnemar's Test P-Value : <2e-16
##
## Sensitivity : 1.0000
## Specificity : 0.0000
## Pos Pred Value : 0.4865
## Neg Pred Value : NaN
## Prevalence : 0.4865
## Detection Rate : 0.4865
## Detection Prevalence : 1.0000
## Balanced Accuracy : 0.5000
##
## 'Positive' Class : FALSE
##
# c
## full model
predTr <- prediction(datTrain$logRegPred, datTrain$aline_flg)
perfTr <- performance(predTr,"tpr","fpr")
plot(perfTr)
text(0.6,0.2,paste0("AUC: ", round(performance(predTr,"auc")@y.values[[1]],3)))
predTe <- prediction(datTest$logRegPred, datTest$aline_flg)
perfTe <- performance(predTe,"tpr","fpr")
lines(perfTe@x.values[[1]],perfTe@y.values[[1]],col='red')
text(0.6,0.1,paste0("AUC: ", round(performance(predTe,"auc")@y.values[[1]],3)),col='red')
## baseline
predTr <- prediction(datTrain$logRegBasePred, datTrain$aline_flg)
perfTr <- performance(predTr,"tpr","fpr")
plot(perfTr)
text(0.6,0.2,paste0("AUC: ", round(performance(predTr,"auc")@y.values[[1]],3)))
predTe <- prediction(datTest$logRegBasePred, datTest$aline_flg)
perfTe <- performance(predTe,"tpr","fpr")
lines(perfTe@x.values[[1]],perfTe@y.values[[1]],col='red')
text(0.6,0.1,paste0("AUC: ", round(performance(predTe,"auc")@y.values[[1]],3)),col='red')
# d
## full model
prop.table(table(datTrain$aline_flg, cut2(datTrain$logRegPred,seq(0,1,0.1))),2)
##
## 0.0 [0.1,0.2) [0.2,0.3) [0.3,0.4) [0.4,0.5) [0.5,0.6) [0.6,0.7)
## 0 0.8500000 0.7411765 0.6514286 0.6022727 0.4512635 0.3539823
## 1 0.1500000 0.2588235 0.3485714 0.3977273 0.5487365 0.6460177
##
## [0.7,0.8) [0.8,0.9) [0.9,1.0]
## 0 0.1986301 0.1919192 0.0000000
## 1 0.8013699 0.8080808 1.0000000
gbm::calibrate.plot(datTrain$aline_flg, datTrain$logRegPred) # training
prop.table(table(datTest$aline_flg, cut2(datTest$logRegPred,seq(0,1,0.1))),2)
##
## 0.0 [0.1,0.2) [0.2,0.3) [0.3,0.4) [0.4,0.5) [0.5,0.6) [0.6,0.7)
## 0 0.8620690 0.7335423 0.6728111 0.4096386 0.3626761 0.3277311
## 1 0.1379310 0.2664577 0.3271889 0.5903614 0.6373239 0.6722689
##
## [0.7,0.8) [0.8,0.9) [0.9,1.0]
## 0 0.2764228 0.1818182 0.2000000
## 1 0.7235772 0.8181818 0.8000000
gbm::calibrate.plot(datTest$aline_flg, datTest$logRegPred) # testing
## baseline
prop.table(table(datTrain$aline_flg, cut2(datTrain$logRegBasePred,seq(0,1,0.1))),2)
##
## 0.0 0.1 0.2 0.3 [0.4,0.5) 0.5 0.6 0.7 0.8 [0.9,1.0]
## 0 0.5007267
## 1 0.4992733
#gbm::calibrate.plot(datTrain$aline_flg, datTrain$logRegBasePred) # can't be plotted
prop.table(table(datTest$aline_flg, cut2(datTest$logRegBasePred,seq(0,1,0.1))),2)
##
## 0.0 0.1 0.2 0.3 [0.4,0.5) 0.5 0.6 0.7 0.8 [0.9,1.0]
## 0 0.4865455
## 1 0.5134545
#gbm::calibrate.plot(datTest$aline_flg, datTest$logRegBasePred) # can't be plotted
- The accuracy of the full model with threshold 0.5 is 0.6863; the accuracy of the baseline model is 0.4936 (it predicted FALSE, i.e., no arterial line, for every patient).
- The accuracy of the full model is 0.6802 in the training dataset and 0.6807 in the test dataset. For the baseline model, training accuracy is 0.5007 and test accuracy is 0.4865.
- Please see the above ROC plot.
- Please see the above calibration curves. The calibration curve on the training data is better calibrated (closer to the diagonal reference line) than on the test data, since the model was fit to the training data. Note that no calibration curve can be plotted for the baseline model, because its predictions are constant.
There are a variety of additional methods that can be used for prediction of different outcome types, and several ways to evaluate model fits or tune parameters. The caret package is a flexible and powerful package which provides a unified framework for building, evaluating, and tuning models.
Thus far we have focused on evaluation using a held-out test set. In the examples we have worked through so far, we have not had to choose any tuning parameters. Tuning parameters are present in many prediction/machine learning algorithms, and there is usually no good a priori way to pick the value that will make the algorithm work best.
To help us choose, we will utilize a validation set (or rather validation sets). k-fold cross validation is a frequently used technique to help choose a tuning parameter and give a preliminary assessment of how well the model will perform on data it was not trained on. k-fold cross validation involves: splitting the training data into k roughly equal parts (folds); for each fold, fitting the model on the other k-1 folds and evaluating it on the held-out fold; and averaging the k performance estimates.
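Spelled out by hand, the whole procedure is only a few lines. A sketch on simulated toy data (not the ICU dataset):

```r
set.seed(1)
n <- 100; k <- 5
d <- data.frame(x = rnorm(n))
d$y <- rbinom(n, 1, plogis(d$x))          # toy binary outcome depending on x
folds <- sample(rep(1:k, length.out = n)) # assign each row to one of k folds
acc <- sapply(1:k, function(i) {
  fit  <- glm(y ~ x, data = d[folds != i, ], family = "binomial")  # train on k-1 folds
  pred <- predict(fit, newdata = d[folds == i, ], type = "response") > 0.5
  mean(pred == (d$y[folds == i] == 1))    # held-out accuracy for fold i
})
mean(acc)                                  # cross-validated accuracy estimate
```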
Normally, doing this manually would require you to partition the data, build the models, evaluate them, choose the best tuning parameter, and summarize the performance. The caret package lets you do all of this quite easily, for a variety of different approaches.
We first need to tell caret how we wish to do the cross-validation. caret also lets you use a few other methods instead of cross validation, but cross validation is the most common. The trainControl function call below tells caret we wish to evaluate the models using cross-validation (“cv”), and use \(k=5\). We include the last two arguments (classProbs and summaryFunction) to allow caret to pick the best model based on area under the ROC curve.
library(caret);
cvTr <- trainControl(method="cv",number=5,classProbs = TRUE,summaryFunction=twoClassSummary)
Now we will run the training and evaluation code. This is done similarly to how you would fit a logistic regression, but using the train function, and one additional parameter, trControl, which we will pass cvTr to, which we created above.
train can be pretty picky about the types of data it allows. It is best to convert binary 0/1 outcomes to a factor with descriptive labels. For example,
dat$day_28_flg <- as.factor(ifelse(dat$day_28_flg==1,"Died","Survived"))
mort.tr.logit <- train(day_28_flg ~ aline_flg + age.cat + sofa_cat + service_unit2 + renal_flg + chf_flg + cad_flg + stroke_flg + mal_flg + resp_flg,data=dat,family="binomial",method="glm",trControl=cvTr,metric="ROC")
print(mort.tr.logit)
## Generalized Linear Model
##
## 2751 samples
## 10 predictor
## 2 classes: 'Died', 'Survived'
##
## No pre-processing
## Resampling: Cross-Validated (5 fold)
## Summary of sample sizes: 2201, 2202, 2201, 2200, 2200
## Resampling results:
##
## ROC Sens Spec
## 0.8318544 0.3022088 0.9687247
##
##
There are no tuning parameters for this model, so the output is pretty basic. You can run summary to get information about the logistic regression model fit, and make predictions much as we did before.
summary(mort.tr.logit)
##
## Call:
## NULL
##
## Deviance Residuals:
## Min 1Q Median 3Q Max
## -2.9417 0.2048 0.2906 0.5027 1.9002
##
## Coefficients:
## Estimate Std. Error z value Pr(>|z|)
## (Intercept) 3.80050 0.22133 17.171 < 2e-16 ***
## aline_flg 0.02825 0.13478 0.210 0.83400
## `age.cat[ 50, 60)` -0.70631 0.24282 -2.909 0.00363 **
## `age.cat[ 60, 70)` -1.37080 0.22770 -6.020 1.74e-09 ***
## `age.cat[ 70, 80)` -1.81831 0.21729 -8.368 < 2e-16 ***
## `age.cat[ 80,300]` -2.58962 0.21106 -12.270 < 2e-16 ***
## `sofa_cat[ 4, 7)` -0.25631 0.14661 -1.748 0.08043 .
## `sofa_cat[ 7,14]` -0.51149 0.19158 -2.670 0.00759 **
## service_unit2NMED -0.17722 0.21414 -0.828 0.40792
## service_unit2NSURG -0.15114 0.21851 -0.692 0.48915
## service_unit2Other 0.33926 0.23511 1.443 0.14903
## service_unit2SURG 1.39823 0.39424 3.547 0.00039 ***
## service_unit2TRAUM 0.02558 0.23906 0.107 0.91480
## renal_flg -0.03357 0.24058 -0.140 0.88901
## chf_flg 0.96178 0.24702 3.894 9.88e-05 ***
## cad_flg 0.43435 0.18891 2.299 0.02149 *
## stroke_flg -1.74596 0.17918 -9.744 < 2e-16 ***
## mal_flg -0.79214 0.16578 -4.778 1.77e-06 ***
## resp_flg -0.65706 0.13896 -4.728 2.26e-06 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## (Dispersion parameter for binomial family taken to be 1)
##
## Null deviance: 2340.8 on 2750 degrees of freedom
## Residual deviance: 1737.2 on 2732 degrees of freedom
## AIC: 1775.2
##
## Number of Fisher Scoring iterations: 6
dat$mort.tr.logit.pred <- predict(mort.tr.logit,newdata=dat,type="prob")
Next, you might think about simplifying the model, using AIC or some other metric. caret will do this as well, just replace the method with glmStepAIC:
mort.tr.logitaic <- train(as.factor(day_28_flg) ~ aline_flg + age.cat + sofa_cat + service_unit2 + renal_flg + chf_flg + cad_flg + stroke_flg + mal_flg + resp_flg,data=dat,family="binomial",method="glmStepAIC",trControl=cvTr,metric="ROC",trace=0)
## Loading required package: MASS
print(mort.tr.logitaic)
## Generalized Linear Model with Stepwise Feature Selection
##
## 2751 samples
## 10 predictor
## 2 classes: 'Died', 'Survived'
##
## No pre-processing
## Resampling: Cross-Validated (5 fold)
## Summary of sample sizes: 2201, 2200, 2201, 2200, 2202
## Resampling results:
##
## ROC Sens Spec
## 0.8321957 0.3141423 0.966587
##
##
summary(mort.tr.logitaic)
##
## Call:
## NULL
##
## Deviance Residuals:
## Min 1Q Median 3Q Max
## -2.9555 0.2118 0.2888 0.5087 1.8715
##
## Coefficients:
## Estimate Std. Error z value Pr(>|z|)
## (Intercept) 3.7865 0.1903 19.898 < 2e-16 ***
## `age.cat[ 50, 60)` -0.7250 0.2409 -3.009 0.002619 **
## `age.cat[ 60, 70)` -1.3966 0.2245 -6.222 4.91e-10 ***
## `age.cat[ 70, 80)` -1.8451 0.2135 -8.642 < 2e-16 ***
## `age.cat[ 80,300]` -2.6216 0.2072 -12.654 < 2e-16 ***
## `sofa_cat[ 4, 7)` -0.2540 0.1455 -1.746 0.080896 .
## `sofa_cat[ 7,14]` -0.4964 0.1903 -2.608 0.009096 **
## service_unit2Other 0.3889 0.2205 1.763 0.077862 .
## service_unit2SURG 1.4525 0.3807 3.816 0.000136 ***
## chf_flg 0.9672 0.2425 3.989 6.63e-05 ***
## cad_flg 0.4337 0.1863 2.328 0.019899 *
## stroke_flg -1.8413 0.1445 -12.740 < 2e-16 ***
## mal_flg -0.7991 0.1639 -4.875 1.09e-06 ***
## resp_flg -0.6302 0.1310 -4.810 1.51e-06 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## (Dispersion parameter for binomial family taken to be 1)
##
## Null deviance: 2340.8 on 2750 degrees of freedom
## Residual deviance: 1738.1 on 2737 degrees of freedom
## AIC: 1766.1
##
## Number of Fisher Scoring iterations: 6
Again we do not have a tuning parameter in this method, so we just get a cross-validated estimate of performance.
There are many other methods, and we will demonstrate a few. glmnet fits the logistic regression with penalization terms. We will just use the default settings, which involve two parameters governing the penalization (see ?glmnet for more details).
This is our first technique which has tuning parameters. The first plot below illustrates how performance varies as we try different values of the tuning parameters. Typically you would try many more than the nine combinations we try below, but this is sufficient for a start.
Additionally, a variable importance plot is also printed below. Each method calculates importance differently; see ?varImp for how this is done for the method you use.
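If you want to control the grid of candidate values rather than accept the defaults, you can pass train a tuneGrid (a sketch; the alpha and lambda values below are arbitrary choices for illustration):

```r
# Candidate (alpha, lambda) pairs for glmnet: every combination of these values
grid <- expand.grid(alpha  = c(0.1, 0.55, 1),
                    lambda = 10^seq(-4, -1, length.out = 5))
nrow(grid)  # 15 candidate pairs
# then: train(..., method = "glmnet", tuneGrid = grid, trControl = cvTr, metric = "ROC")
```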
mort.tr.logitglmnet <- train(as.factor(day_28_flg) ~ aline_flg + age.cat + sofa_cat + service_unit2 + renal_flg + chf_flg + cad_flg + stroke_flg + mal_flg + resp_flg,data=dat,family="binomial",method="glmnet",trControl=cvTr,metric="ROC")
## Loading required package: glmnet
## Loading required package: Matrix
## Loading required package: foreach
## Loaded glmnet 2.0-5
print(mort.tr.logitglmnet )
## glmnet
##
## 2751 samples
## 10 predictor
## 2 classes: 'Died', 'Survived'
##
## No pre-processing
## Resampling: Cross-Validated (5 fold)
## Summary of sample sizes: 2202, 2200, 2201, 2200, 2201
## Resampling results across tuning parameters:
##
## alpha lambda ROC Sens Spec
## 0.10 0.0002626832 0.8345521 0.2899312 0.9678654
## 0.10 0.0026268321 0.8346514 0.2851406 0.9687219
## 0.10 0.0262683213 0.8315069 0.2348250 0.9811480
## 0.55 0.0002626832 0.8346449 0.2875502 0.9674371
## 0.55 0.0026268321 0.8348602 0.2803213 0.9712933
## 0.55 0.0262683213 0.8220185 0.2132530 0.9858608
## 1.00 0.0002626832 0.8346916 0.2875502 0.9674371
## 1.00 0.0026268321 0.8345703 0.2779116 0.9721517
## 1.00 0.0262683213 0.8103149 0.2012909 0.9858599
##
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were alpha = 0.55 and lambda
## = 0.002626832.
plot(mort.tr.logitglmnet)
plot(varImp(mort.tr.logitglmnet))
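To try more than the default nine combinations, you can pass an explicit grid through train's tuneGrid argument. A sketch, assuming dat and cvTr are defined as above; the grid values here are illustrative, not recommendations:

```r
# Sketch: a denser tuning grid for glmnet via caret's tuneGrid argument.
# The grid's column names must match the method's tuning parameters (alpha, lambda).
grid <- expand.grid(alpha  = seq(0, 1, by = 0.25),
                    lambda = 10^seq(-4, -1, length.out = 10))
mort.tr.glmnet2 <- train(as.factor(day_28_flg) ~ aline_flg + age.cat + sofa_cat +
                           service_unit2 + renal_flg + chf_flg + cad_flg +
                           stroke_flg + mal_flg + resp_flg,
                         data = dat, family = "binomial", method = "glmnet",
                         trControl = cvTr, metric = "ROC", tuneGrid = grid)
plot(mort.tr.glmnet2)  # cross-validated ROC across the 50 alpha/lambda combinations
```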
Random forests (rf) are an ensemble method that builds many simple decision trees using a technique called bagging. Note we can also drop the family="binomial".
mort.tr.logitrf<- train(as.factor(day_28_flg) ~ aline_flg + age.cat + sofa_cat + service_unit2 + renal_flg + chf_flg + cad_flg + stroke_flg + mal_flg + resp_flg,data=dat,method="rf",trControl=cvTr,importance=TRUE,metric="ROC")
## Loading required package: randomForest
## randomForest 4.6-12
## Type rfNews() to see new features/changes/bug fixes.
##
## Attaching package: 'randomForest'
## The following object is masked from 'package:Hmisc':
##
## combine
## The following object is masked from 'package:ggplot2':
##
## margin
print(mort.tr.logitrf )
## Random Forest
##
## 2751 samples
## 10 predictor
## 2 classes: 'Died', 'Survived'
##
## No pre-processing
## Resampling: Cross-Validated (5 fold)
## Summary of sample sizes: 2201, 2201, 2201, 2201, 2200
## Resampling results across tuning parameters:
##
## mtry ROC Sens Spec
## 2 0.7999870 0.1295468 0.9940015
## 10 0.7908165 0.3477912 0.9468749
## 18 0.7834080 0.3428571 0.9413056
##
## ROC was used to select the optimal model using the largest value.
## The final value used for the model was mtry = 2.
plot(mort.tr.logitrf)
plot(varImp(mort.tr.logitrf))
Stochastic gradient boosting (gbm) uses a general technique known as boosting.
mort.tr.logitgbm<- train(as.factor(day_28_flg) ~ aline_flg + age.cat + sofa_cat + service_unit2 + renal_flg + chf_flg + cad_flg + stroke_flg + mal_flg + resp_flg,data=dat,method="gbm",trControl=cvTr,verbose=FALSE,metric="ROC")
## Loading required package: gbm
## Loading required package: splines
## Loading required package: parallel
## Loaded gbm 2.1.1
print(mort.tr.logitgbm )
## Stochastic Gradient Boosting
##
## 2751 samples
## 10 predictor
## 2 classes: 'Died', 'Survived'
##
## No pre-processing
## Resampling: Cross-Validated (5 fold)
## Summary of sample sizes: 2201, 2200, 2201, 2201, 2201
## Resampling results across tuning parameters:
##
## interaction.depth n.trees ROC Sens Spec
## 1 50 0.8163341 0.2133678 0.9832875
## 1 100 0.8262332 0.2588927 0.9747121
## 1 150 0.8301358 0.2733219 0.9717152
## 2 50 0.8286889 0.2444349 0.9789966
## 2 100 0.8317131 0.3043603 0.9682845
## 2 150 0.8324614 0.3163798 0.9704295
## 3 50 0.8316223 0.2806081 0.9747176
## 3 100 0.8347733 0.3092083 0.9704276
## 3 150 0.8364044 0.3284567 0.9695711
##
## Tuning parameter 'shrinkage' was held constant at a value of 0.1
##
## Tuning parameter 'n.minobsinnode' was held constant at a value of 10
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were n.trees = 150,
## interaction.depth = 3, shrinkage = 0.1 and n.minobsinnode = 10.
plot(mort.tr.logitgbm)
plot(varImp(mort.tr.logitgbm))
The default settings optimize the chosen metric (here, ROC) on the validation folds and pick the tuning-parameter values accordingly. The last step refits the model with the optimal tuning parameters on the complete dataset. To access the predictions from this fit, we add new columns to the dat data frame:
dat$predMort1 <- predict(mort.tr.logit,type="prob")[,2]
dat$predMort2 <- predict(mort.tr.logitglmnet,type="prob")[,2]
dat$predMort3 <- predict(mort.tr.logitrf,type="prob")[,2]
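Since the fits above share the same trControl, their cross-validated performance can also be compared side by side with caret's resamples() function. A sketch (note that for a strictly paired comparison the folds should be fixed in advance via the index or seeds argument of trainControl):

```r
# Sketch: compare cross-validated ROC across the fitted models.
# For a strictly paired comparison, fix the folds via trainControl(index = ...).
res <- resamples(list(glmnet = mort.tr.logitglmnet,
                      rf     = mort.tr.logitrf,
                      gbm    = mort.tr.logitgbm))
summary(res)                 # min/median/max ROC, Sens, Spec per model
bwplot(res, metric = "ROC")  # box plots of fold-wise ROC
```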
A complete list of methods you can use with caret is available:
https://topepo.github.io/caret/modelList.html
- a) Setup caret to fit a logistic regression model for aline_flg, similar to what you did earlier. How do these results compare with what you did before? Hint: Make sure your aline_flg variable is a factor (and doesn't have 0/1 levels) for this part and all other parts below (see above for an example).
- b) Using method="rf", set an additional argument in train to tuneLength=5, and run a random forest model for aline_flg.
- c) Go to the link above. Pick one additional method not used thus far (note: the method should be of type "Classification" or "Dual Use"). Ones checked to work: nnet, dnn, knn, xgbLinear and svmLinear2; use the default tuning length. Comment on the performance. (Hint: verbose=FALSE or trace=FALSE suppresses the noisy output while fitting for some models. If this fails, try wrapping the entire train call with suppressMessages. svmLinear2 also needs probability=TRUE, and adding maxit=300 to nnet will make sure it fits properly.)
- d) For a), b), and c), create a new column in the dat data frame for the predictions from each method. Using these predictions, plot ROC curves and compute the AUCs. Try to put all curves on the same plot; if you have difficulties, 3 separate plots is OK. Why do the AUCs differ across parts a)-c)?
- e) Create a calibration plot for the predictions from parts a)-c). How well calibrated is each method?
- f) Discuss which method had the best performance. Imagine (hypothetically) you would like to develop an algorithm that would suggest an arterial line for types of patients who frequently get them (ignore the effectiveness of the procedure). What additional steps might you want to take before deploying the algorithm "in the wild"?
An aside: some methods in part c) will perform badly on this dataset with the default tuning parameters. This does not mean they are generally bad methods; they might not even be bad methods for this dataset given the right tuning parameters. It goes to show that these methods cannot be applied blindly, and that finding the best way to construct these prediction algorithms is sometimes as much an art as a science.
# a
dat <- read.csv("aline-dataset.csv")
dat$sofa_cat <- cut2(dat$sofa_first,c(0,4,7))
dat$age.cat <- cut2(dat$age,c(50,60,70,80))
dat$service_unit2 <- as.character(dat$service_unit)
dat$service_unit2[dat$service_unit2 %in% names(which(table(dat$service_unit)<200))] <- "Other"
cvTr <- trainControl(method="cv", number=5, classProbs=TRUE,
summaryFunction=twoClassSummary)
dat$aline_flg <- as.factor(ifelse(dat$aline_flg==1, "Aline", "No_Aline"))
aline.tr.logit <- train(aline_flg ~ age.cat + sofa_cat + service_unit2 + renal_flg + chf_flg + cad_flg + stroke_flg + mal_flg + resp_flg,
data=dat, family="binomial", method="glm",
trControl=cvTr, metric="ROC")
print(aline.tr.logit)
## Generalized Linear Model
##
## 2751 samples
## 9 predictor
## 2 classes: 'Aline', 'No_Aline'
##
## No pre-processing
## Resampling: Cross-Validated (5 fold)
## Summary of sample sizes: 2201, 2200, 2201, 2201, 2201
## Resampling results:
##
## ROC Sens Spec
## 0.7240021 0.7343648 0.6354732
##
##
#plot(aline.tr.logit)
plot(varImp(aline.tr.logit))
# b
aline.tr.rf <- train(aline_flg ~ age.cat + sofa_cat + service_unit2 + renal_flg + chf_flg + cad_flg + stroke_flg + mal_flg + resp_flg,
data=dat, method="rf", tuneLength=5,
trControl=cvTr, metric="ROC")
print(aline.tr.rf)
## Random Forest
##
## 2751 samples
## 9 predictor
## 2 classes: 'Aline', 'No_Aline'
##
## No pre-processing
## Resampling: Cross-Validated (5 fold)
## Summary of sample sizes: 2202, 2200, 2200, 2201, 2201
## Resampling results across tuning parameters:
##
## mtry ROC Sens Spec
## 2 0.7095592 0.6956370 0.6613062
## 5 0.7049545 0.7049638 0.6406935
## 9 0.6973995 0.6884660 0.6362709
## 13 0.6932429 0.6812872 0.6333189
## 17 0.6894750 0.6741162 0.6325781
##
## ROC was used to select the optimal model using the largest value.
## The final value used for the model was mtry = 2.
plot(aline.tr.rf)
plot(varImp(aline.tr.rf))
# c
aline.tr.nnet <- train(aline_flg ~ age.cat + sofa_cat + service_unit2 + renal_flg + chf_flg + cad_flg + stroke_flg + mal_flg + resp_flg,
                       data=dat, method="nnet", trace=FALSE, # trace=FALSE suppresses the per-iteration fitting log (see the hint in part c)
                       trControl=cvTr, metric="ROC")
## Loading required package: nnet
print(aline.tr.nnet)
## Neural Network
##
## 2751 samples
## 9 predictor
## 2 classes: 'Aline', 'No_Aline'
##
## No pre-processing
## Resampling: Cross-Validated (5 fold)
## Summary of sample sizes: 2201, 2200, 2202, 2201, 2200
## Resampling results across tuning parameters:
##
## size decay ROC Sens Spec
## 1 0e+00 0.7106178 0.7271989 0.6354515
## 1 1e-04 0.7212694 0.7279157 0.6428099
## 1 1e-01 0.7225486 0.7300663 0.6376547
## 3 0e+00 0.7095568 0.7149945 0.6369411
## 3 1e-04 0.6805661 0.7730564 0.4927664
## 3 1e-01 0.7204115 0.7142828 0.6516687
## 5 0e+00 0.6893622 0.7042186 0.6251329
## 5 1e-04 0.6963404 0.7078079 0.6037796
## 5 1e-01 0.7145812 0.7042392 0.6376655
##
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were size = 1 and decay = 0.1.
plot(aline.tr.nnet)
plot(varImp(aline.tr.nnet))
# d
dat$predLogit <- predict(aline.tr.logit, type="prob")[,2]
dat$predRf <- predict(aline.tr.rf, type="prob")[,2]
dat$predNNet <- predict(aline.tr.nnet, type="prob")[,2]
library(ROCR) # provides prediction() and performance()
predL <- prediction(dat$predLogit, dat$aline_flg)
perfL <- performance(predL, "tpr", "fpr")
plot(perfL)
text(0.6, 0.3, paste0("AUC: ", round(performance(predL, "auc")@y.values[[1]],3)))
predR <- prediction(dat$predRf, dat$aline_flg)
perfR <- performance(predR, "tpr", "fpr")
lines(perfR@x.values[[1]], perfR@y.values[[1]], col='red')
text(0.6, 0.2, paste0("AUC: ", round(performance(predR, "auc")@y.values[[1]],3)), col="red")
predN <- prediction(dat$predNNet, dat$aline_flg)
perfN <- performance(predN, "tpr", "fpr")
lines(perfN@x.values[[1]], perfN@y.values[[1]], col='blue')
text(0.6, 0.1, paste0("AUC: ", round(performance(predN, "auc")@y.values[[1]],3)), col="blue")
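As an alternative to ROCR, the pROC package (an extra dependency, assuming it is installed) can draw the same curves and compute AUCs with confidence intervals. A sketch using the prediction columns created above:

```r
# Sketch: ROC curves and AUCs with the pROC package instead of ROCR.
library(pROC)
rocL <- roc(dat$aline_flg, dat$predLogit)  # logistic regression
rocR <- roc(dat$aline_flg, dat$predRf)     # random forest
rocN <- roc(dat$aline_flg, dat$predNNet)   # neural network
plot(rocL)
lines(rocR, col = "red")
lines(rocN, col = "blue")
legend("bottomright", lwd = 2, col = c("black", "red", "blue"),
       legend = sprintf("AUC %.3f", c(auc(rocL), auc(rocR), auc(rocN))))
ci.auc(rocL)  # confidence interval for one AUC (DeLong by default)
```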
# e
prop.table(table(dat$aline_flg, cut2(dat$predLogit, seq(0,1,0.1))), 2)
##
## [0.0,0.1) [0.1,0.2) [0.2,0.3) [0.3,0.4) [0.4,0.5) [0.5,0.6)
## Aline 1.0000000 0.8171429 0.7493188 0.6381381 0.5789474 0.4520548
## No_Aline 0.0000000 0.1828571 0.2506812 0.3618619 0.4210526 0.5479452
##
## [0.6,0.7) [0.7,0.8) [0.8,0.9) [0.9,1.0]
## Aline 0.3059701 0.2651391 0.1833333
## No_Aline 0.6940299 0.7348609 0.8166667
gbm::calibrate.plot(dat$aline_flg, dat$predLogit)
prop.table(table(dat$aline_flg, cut2(dat$predRf, seq(0,1,0.1))), 2)
##
## [0.0,0.1) [0.1,0.2) [0.2,0.3) [0.3,0.4) [0.4,0.5) [0.5,0.6)
## Aline 0.8857143 0.8697917 0.6848592 0.7012712 0.5671642 0.4222222
## No_Aline 0.1142857 0.1302083 0.3151408 0.2987288 0.4328358 0.5777778
##
## [0.6,0.7) [0.7,0.8) [0.8,0.9) [0.9,1.0]
## Aline 0.3565217 0.3395062 0.2600619 0.2262774
## No_Aline 0.6434783 0.6604938 0.7399381 0.7737226
gbm::calibrate.plot(dat$aline_flg, dat$predRf)
prop.table(table(dat$aline_flg, cut2(dat$predNNet, seq(0,1,0.1))), 2)
##
## 0.0 [0.1,0.2) [0.2,0.3) [0.3,0.4) [0.4,0.5) [0.5,0.6) [0.6,0.7)
## Aline 0.8109756 0.7667560 0.6347826 0.5754386 0.4774194 0.3050398
## No_Aline 0.1890244 0.2332440 0.3652174 0.4245614 0.5225806 0.6949602
##
## [0.7,0.8) [0.8,0.9) [0.9,1.0]
## Aline 0.2632375 0.1956522
## No_Aline 0.7367625 0.8043478
gbm::calibrate.plot(dat$aline_flg, dat$predNNet)
- Please see the results of the random forest model above.
- Please see the results of the neural network model above.
- Different algorithms yield different accuracy, sensitivity, specificity, PPV, and NPV; therefore their AUCs also differ.
- The neural network model has better calibration than logistic regression, followed by random forest. A better-calibrated model is smoother and lies closer to the red diagonal line.
- Random forest provides the best performance based on AUC. However, before deploying the model in the real world we still need to check that it does not overfit. Error analysis, portability testing (using more data or data from other ICUs), and interpretability should all be assessed. We need to verify that the top features in the model make sense and are interpretable; otherwise it would be hard to implement in a real ICU setting without a convincing model.